package com.luo.requestWrapper;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.luo.util.XSSUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StreamUtils;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.util.Map;
import java.util.Set;


/**
* 
* 过滤@RequestBody实体类中的所有参数
* @author Zhiguang
* @since 2021/12/20 3:45
*/

@Slf4j
public class XssFilterServletRequestWrapper extends HttpServletRequestWrapper {
    private byte[] requestBody;
    private Charset charSet;

    public XssFilterServletRequestWrapper(HttpServletRequest request) {
        super(request);

        //缓存请求body
        try {

            String requestBodyStr = getRequestPostStr(request);
            JSONObject resultJson = JSON.parseObject(requestBodyStr);

            Set<String> keySet = resultJson.keySet();
            for (String key : keySet) {
                Object o = resultJson.get(key);
                if (o != null){
                    //过滤掉富文本内的非法标签
                    o = XSSUtil.jsoupCleanRichText(o.toString());
                }
                resultJson.put(key,o);
            }

            requestBody = resultJson.toString().getBytes(charSet);
        } catch (IOException e) {
            log.error("", e);
        }
    }

    public String getRequestPostStr(HttpServletRequest request)
            throws IOException {
        String charSetStr = request.getCharacterEncoding();
        if (charSetStr == null) {
            charSetStr = "UTF-8";
        }
        charSet = Charset.forName(charSetStr);

        return StreamUtils.copyToString(request.getInputStream(), charSet);
    }

    /**
     * 重写 getInputStream()
     */
    @Override
    public ServletInputStream getInputStream() {
        if (requestBody == null) {
            requestBody = new byte[0];
        }

        final ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(requestBody);

        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return false;
            }

            @Override
            public boolean isReady() {
                return false;
            }

            @Override
            public void setReadListener(ReadListener readListener) {

            }

            @Override
            public int read() {
                return byteArrayInputStream.read();
            }
        };
    }

    /**
     * 重写 getReader()
     */
    @Override
    public BufferedReader getReader() {
        return new BufferedReader(new InputStreamReader(getInputStream()));
    }

}
